8940d1
@@ -2,6 +2,9 @@
package org.springframework.batch.item.xml;
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
 
 import javax.xml.namespace.QName;
 import javax.xml.stream.XMLEventReader;
@@ -16,9 +19,7 @@
import org.springframework.batch.item.ItemStream;
 import org.springframework.batch.item.ItemStreamException;
 import org.springframework.batch.item.ReaderNotOpenException;
 import org.springframework.batch.item.xml.stax.DefaultFragmentEventReader;
-import org.springframework.batch.item.xml.stax.DefaultTransactionalEventReader;
 import org.springframework.batch.item.xml.stax.FragmentEventReader;
-import org.springframework.batch.item.xml.stax.TransactionalEventReader;
 import org.springframework.beans.factory.InitializingBean;
 import org.springframework.core.io.Resource;
 import org.springframework.dao.DataAccessResourceFailureException;
@@ -26,7 +27,7 @@
import org.springframework.util.Assert;
 import org.springframework.util.ClassUtils;
 
 /**
- * Input source for reading XML input based on StAX.
+ * Item reader for reading XML input based on StAX.
  * 
  * It extracts fragments from the input XML document which correspond to records
  * for processing. The fragments are wrapped with StartDocument and EndDocument
@@ -42,7 +43,7 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 
 	private FragmentEventReader fragmentReader;
 
-	private TransactionalEventReader txReader;
+	private XMLEventReader eventReader;
 
 	private EventReaderDeserializer eventReaderDeserializer;
 
@@ -59,6 +60,15 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 	private long currentRecordCount = 0;
 
 	private boolean saveState = false;
+	
+	private List buffer = new ArrayList();
+	
+	private Iterator bufferIterator = null;
+	
+	/**
+	 * indicates the reader has been reset (rolled back) and should read items from buffer
+	 */
+	private boolean shouldReadBuffer = false;
 
 	public StaxEventItemReader() {
 		setName(ClassUtils.getShortName(StaxEventItemReader.class));
@@ -74,15 +84,35 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 		if (!initialized) {
 			throw new ReaderNotOpenException("Reader must be open before it can be read.");
 		}
-		Object item = null;
 
 		currentRecordCount++;
+
+		// read from buffer after rollback: replay the items read since the
+		// last mark() before reading further fragments from the XML stream
+		if (shouldReadBuffer) {
+			if (bufferIterator.hasNext()) {
+				return bufferIterator.next();
+			} else {
+				// buffer is exhausted, continue reading from file
+				shouldReadBuffer = false;
+				bufferIterator = null;
+			}
+		}
+
+		Object item = null;
+
 		if (moveCursorToNextFragment(fragmentReader)) {
 			fragmentReader.markStartFragment();
 			item = eventReaderDeserializer.deserializeFragment(fragmentReader);
 			fragmentReader.markFragmentProcessed();
 		}
 
+		// buffer only real items: the replay path above returns before the
+		// null check below, so a buffered end-of-input null would be replayed
+		// without undoing its currentRecordCount increment
+		if (item != null) {
+			buffer.add(item);
+		}
 		if (item == null) {
 			currentRecordCount--;
 		}
@@ -92,6 +116,7 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 	public void close(ExecutionContext executionContext) {
 		initialized = false;
 		currentRecordCount = 0;
+		clearBuffer();
 		try {
 			if (fragmentReader != null) {
 				fragmentReader.close();
@@ -117,9 +142,9 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 
 		try {
 			inputStream = resource.getInputStream();
-			txReader = new DefaultTransactionalEventReader(XMLInputFactory.newInstance().createXMLEventReader(
-					inputStream));
-			fragmentReader = new DefaultFragmentEventReader(txReader);
+			eventReader = XMLInputFactory.newInstance().createXMLEventReader(
+					inputStream);
+			fragmentReader = new DefaultFragmentEventReader(eventReader);
 		}
 		catch (XMLStreamException xse) {
 			throw new DataAccessResourceFailureException("Unable to create XML reader", xse);
@@ -134,8 +159,9 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 			int REASONABLE_ADHOC_COMMIT_FREQUENCY = 100;
 			while (currentRecordCount <= restoredRecordCount) {
 				currentRecordCount++;
+				
 				if (currentRecordCount % REASONABLE_ADHOC_COMMIT_FREQUENCY == 0) {
-					txReader.onCommit(); // reset the history buffer
+					mark(); // clear the history buffer
 				}
 				if (!fragmentReader.hasNext()) {
 					throw new ItemStreamException("Restore point must be before end of input");
@@ -143,7 +169,7 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 				fragmentReader.next();
 				moveCursorToNextFragment(fragmentReader);
 			}
-			mark(); // reset the history buffer
+			mark(); // clear the history buffer
 		}
 	}
 
@@ -236,7 +262,22 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 	 */
 	public void mark() {
 		lastCommitPointRecordCount = currentRecordCount;
-		txReader.onCommit();
+		if (shouldReadBuffer && bufferIterator != null) {
+			// committed in the middle of a replay: items that have not been
+			// replayed yet must survive the commit, otherwise they would be
+			// skipped once reading resumes from the XML stream
+			List pending = new ArrayList();
+			while (bufferIterator.hasNext()) {
+				pending.add(bufferIterator.next());
+			}
+			buffer.clear();
+			buffer.addAll(pending);
+			bufferIterator = buffer.iterator();
+		}
+		else {
+			clearBuffer();
+			shouldReadBuffer = false;
+		}
 	}
 
 	/*
@@ -246,7 +273,9 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 	 */
 	public void reset() {
 		currentRecordCount = lastCommitPointRecordCount;
-		txReader.onRollback();
+		// replay the items buffered since the last mark() before reading the file again
+		bufferIterator = buffer.iterator();
+		shouldReadBuffer = true;
 		fragmentReader.reset();
 	}
 
@@ -260,4 +288,16 @@
public class StaxEventItemReader extends ExecutionContextUserSupport implements
 	public void setSaveState(boolean saveState) {
 		this.saveState = saveState;
 	}
+
+	/**
+	 * Clear the buffer, release the iterator and leave replay mode.
+	 * Resetting shouldReadBuffer matters for close(): otherwise a reader closed
+	 * after a rollback could dereference the released (null) iterator on its
+	 * first read after being re-opened.
+	 */
+	private void clearBuffer() {
+		buffer.clear();
+		bufferIterator = null;
+		shouldReadBuffer = false;
+	}
 }
